# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#' @include array.R
#' @include chunked-array.R
#' @include scalar.R

# Invoke a compute function by name.
#
# `function_name` must be a single string. The inputs are taken from `...`
# (or from `args`, a list, when supplied explicitly) and every one of them
# must inherit from Array, ChunkedArray, RecordBatch, Table or Scalar.
# `options` must be a list that carries a names attribute. The validated
# inputs are handed to compute__CallFunction(), whose result is returned.
call_function <- function(function_name, ..., args = list(...), options = empty_named_list()) {
  assert_that(is.string(function_name))
  assert_that(is.list(options), !is.null(names(options)))

  datum_classes <- c("Array", "ChunkedArray", "RecordBatch", "Table", "Scalar")
  is_datum <- map_lgl(args, function(arg) inherits(arg, datum_classes))
  if (any(!is_datum)) {
    # Only the first offending argument is reported
    bad_pos <- match(FALSE, is_datum)
    bad_class <- head(class(args[[bad_pos]]), 1)
    stop(
      "Argument ", bad_pos, " is of class ", bad_class,
      " but it must be one of ", oxford_paste(datum_classes, "or"),
      call. = FALSE
    )
  }

  compute__CallFunction(function_name, args, options)
}

# sum() method: forwards every input plus `na.rm` to scalar_aggregate()
# under the function name "sum".
#' @export
sum.Array <- function(..., na.rm = FALSE) {
  scalar_aggregate("sum", ..., na.rm = na.rm)
}

# ChunkedArray reuses the Array implementation unchanged.
#' @export
sum.ChunkedArray <- sum.Array

# Scalar reuses the Array implementation unchanged.
#' @export
sum.Scalar <- sum.Array

# mean() method: forwards every input plus `na.rm` to scalar_aggregate()
# under the function name "mean".
#' @export
mean.Array <- function(..., na.rm = FALSE) {
  scalar_aggregate("mean", ..., na.rm = na.rm)
}

# ChunkedArray reuses the Array implementation unchanged.
#' @export
mean.ChunkedArray <- mean.Array

# Scalar reuses the Array implementation unchanged.
#' @export
mean.Scalar <- mean.Array

# min() method: runs the combined "min_max" aggregation through
# scalar_aggregate() and extracts the field named "min" from its result.
#' @export
min.Array <- function(..., na.rm = FALSE) {
  extremes <- scalar_aggregate("min_max", ..., na.rm = na.rm)
  extremes$GetFieldByName("min")
}

# ChunkedArray reuses the Array implementation unchanged.
#' @export
min.ChunkedArray <- min.Array

# max() method: runs the combined "min_max" aggregation through
# scalar_aggregate() and extracts the field named "max" from its result.
#' @export
max.Array <- function(..., na.rm = FALSE) {
  extremes <- scalar_aggregate("min_max", ..., na.rm = na.rm)
  extremes$GetFieldByName("max")
}

# ChunkedArray reuses the Array implementation unchanged.
#' @export
max.ChunkedArray <- max.Array

# Shared implementation behind the sum/mean/min/max methods.
#
# FUN   name of the compute function to call (e.g. "sum", "min_max")
# ...   the inputs; combined into one object by collect_arrays_from_dots()
# na.rm forwarded to the compute function as the `na.rm` option
#
# The result of call_function() is passed through Scalar$create().
scalar_aggregate <- function(FUN, ..., na.rm = FALSE) {
  values <- collect_arrays_from_dots(list(...))

  # Arrow sum/mean function always drops NAs, so when the caller did not ask
  # for NA removal and nulls are present, answer with NA directly.
  # https://issues.apache.org/jira/browse/ARROW-9054
  propagate_na <- !na.rm && values$null_count > 0 && FUN %in% c("mean", "sum")
  if (propagate_na) {
    return(Scalar$create(NA_real_))
  }

  result <- call_function(FUN, values, options = list(na.rm = na.rm))
  Scalar$create(result)
}

# Combine a list of Arrays and/or ChunkedArrays into one object.
#
# Every element of `dots` must satisfy is.Array(). A single-element list is
# returned as that element, untouched (so the result may be a plain Array).
# Otherwise each ChunkedArray is expanded into its chunks and all pieces are
# spliced into ChunkedArray$create().
#
# NOTE(review): sum.Scalar and mean.Scalar also reach this function through
# scalar_aggregate(); confirm that is.Array() accepts a Scalar, otherwise
# those methods would stop at the assertion below.
collect_arrays_from_dots <- function(dots) {
  assert_that(all(map_lgl(dots, is.Array)))
  if (length(dots) == 1) {
    return(dots[[1]])
  }

  # A ChunkedArray contributes its list of chunks; anything else contributes
  # itself.
  as_chunks <- function(x) {
    if (inherits(x, "ChunkedArray")) x$chunks else x
  }
  pieces <- unlist(lapply(dots, as_chunks))
  ChunkedArray$create(!!!pieces)
}

# unique() method: calls the "unique" compute function on `x` and wraps the
# result with Array$create(). `incomparables` and `...` are accepted for
# compatibility with the generic's signature but are not used.
#' @export
unique.Array <- function(x, incomparables = FALSE, ...) {
  distinct_values <- call_function("unique", x)
  Array$create(distinct_values)
}

# ChunkedArray reuses the Array implementation unchanged.
#' @export
unique.ChunkedArray <- unique.Array

#' `match` for Arrow objects
#'
#' `base::match()` is not a generic, so Arrow methods cannot be registered for
#' it directly. `match_arrow()` is a generic that exposes the analogous
#' function in the Arrow C++ library for `Array` and `ChunkedArray` inputs.
#' For any other `x`, the default method simply calls `match()`, which
#' returns 1-based positions.
#'
#' @param x `Array` or `ChunkedArray`
#' @param table `Array`, `ChunkedArray`, or R vector lookup table.
#' @param ... additional arguments, ignored
#' @return An `int32`-type `Array` of the same length as `x` with the
#' (0-based) indexes into `table`.
#' @export
match_arrow <- function(x, table, ...) {
  UseMethod("match_arrow")
}

# Fallback for inputs without a dedicated method: defer to match(),
# forwarding any extra arguments.
#' @export
match_arrow.default <- function(x, table, ...) {
  match(x, table, ...)
}

# Array method: a `table` that is neither an Array nor a ChunkedArray is
# first converted with Array$create(); the lookup itself is done by the
# "index_in_meta_binary" compute function and the result is wrapped with
# Array$create().
#' @export
match_arrow.Array <- function(x, table, ...) {
  table_is_arrow <- inherits(table, c("Array", "ChunkedArray"))
  if (!table_is_arrow) {
    table <- Array$create(table)
  }
  positions <- call_function("index_in_meta_binary", x, table)
  Array$create(positions)
}

# ChunkedArray method: a `table` that is neither an Array nor a ChunkedArray
# is first converted with Array$create(); the lookup itself is done by the
# "index_in_meta_binary" compute function and the result is wrapped as a
# ChunkedArray via shared_ptr().
#' @export
match_arrow.ChunkedArray <- function(x, table, ...) {
  table_is_arrow <- inherits(table, c("Array", "ChunkedArray"))
  if (!table_is_arrow) {
    table <- Array$create(table)
  }
  positions <- call_function("index_in_meta_binary", x, table)
  shared_ptr(ChunkedArray, positions)
}

# R6 class for cast options. It declares no members of its own; everything
# comes from the ArrowObject parent class. Instances are produced by
# cast_options().
CastOptions <- R6Class(
  "CastOptions",
  inherit = ArrowObject
)

#' Cast options
#'
#' @param safe enforce safe conversion. Each `allow_*` argument defaults to
#' `!safe`, so leaving `safe = TRUE` disables all three unless they are set
#' explicitly.
#' @param allow_int_overflow allow int conversion, `!safe` by default
#' @param allow_time_truncate allow time truncate, `!safe` by default
#' @param allow_float_truncate allow float truncate, `!safe` by default
#'
#' @export
cast_options <- function(safe = TRUE,
                         allow_int_overflow = !safe,
                         allow_time_truncate = !safe,
                         allow_float_truncate = !safe) {
  # `safe` itself is not passed on; it only feeds the three defaults above.
  options_ptr <- compute___CastOptions__initialize(
    allow_int_overflow,
    allow_time_truncate,
    allow_float_truncate
  )
  shared_ptr(CastOptions, options_ptr)
}
